This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit d84ad5e585c6573635a388de4a5602fe4229fa06 Author: dalong01.liu <[email protected]> AuthorDate: Mon Apr 20 17:39:25 2020 +0800 [FLINK-17096][table-blink] Mini-batch group aggregation doesn't expire state even if state ttl is enabled This closes #11830 --- .../stream/StreamExecGlobalGroupAggregate.scala | 3 +- .../physical/stream/StreamExecGroupAggregate.scala | 8 +- .../stream/StreamExecGroupTableAggregate.scala | 5 +- .../StreamExecIncrementalGroupAggregate.scala | 3 +- .../harness/GroupAggregateHarnessTest.scala | 8 +- .../harness/TableAggregateHarnessTest.scala | 12 +- .../runtime/dataview/PerKeyStateDataViewStore.java | 13 ++ .../table/runtime/dataview/StateListView.java | 8 +- .../operators/aggregate/GroupAggFunction.java | 46 +++--- .../operators/aggregate/GroupTableAggFunction.java | 41 +++-- .../aggregate/MiniBatchGlobalGroupAggFunction.java | 19 ++- .../aggregate/MiniBatchGroupAggFunction.java | 18 ++- .../MiniBatchIncrementalGroupAggFunction.java | 15 +- .../join/stream/AbstractStreamingJoinOperator.java | 6 +- .../join/stream/StreamingJoinOperator.java | 12 +- .../join/stream/StreamingSemiAntiJoinOperator.java | 8 +- .../operators/aggregate/GroupAggFunctionTest.java | 98 +++++++++++ .../aggregate/GroupAggFunctionTestBase.java | 180 +++++++++++++++++++++ .../aggregate/MiniBatchGroupAggFunctionTest.java | 98 +++++++++++ 19 files changed, 514 insertions(+), 87 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala index 7caa79a..ee007ef 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala @@ -162,7 +162,8 @@ class StreamExecGlobalGroupAggregate( recordEqualiser, globalAccTypes, indexOfCountStar, - generateUpdateBefore) + generateUpdateBefore, + tableConfig.getMinIdleStateRetentionTime) new KeyedMapBundleOperator( aggFunction, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala index c3a44b8..5d1d249 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala @@ -157,20 +157,20 @@ class StreamExecGroupAggregate( accTypes, inputRowType, inputCountIndex, - generateUpdateBefore) + generateUpdateBefore, + tableConfig.getMinIdleStateRetentionTime) new KeyedMapBundleOperator( aggFunction, AggregateUtil.createMiniBatchTrigger(tableConfig)) } else { val aggFunction = new GroupAggFunction( - tableConfig.getMinIdleStateRetentionTime, - tableConfig.getMaxIdleStateRetentionTime, aggsHandler, recordEqualiser, accTypes, inputCountIndex, - generateUpdateBefore) + generateUpdateBefore, + tableConfig.getMinIdleStateRetentionTime) val operator = new KeyedProcessOperator[RowData, RowData, RowData](aggFunction) operator diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala index 
525d328..d808aa8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala @@ -138,12 +138,11 @@ class StreamExecGroupTableAggregate( val inputCountIndex = aggInfoList.getIndexOfCountStar val aggFunction = new GroupTableAggFunction( - tableConfig.getMinIdleStateRetentionTime, - tableConfig.getMaxIdleStateRetentionTime, aggsHandler, accTypes, inputCountIndex, - generateUpdateBefore) + generateUpdateBefore, + tableConfig.getMinIdleStateRetentionTime) val operator = new KeyedProcessOperator[RowData, RowData, RowData](aggFunction) val selector = KeySelectorUtil.getRowDataSelector( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala index dd8527d..2d00216 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala @@ -166,7 +166,8 @@ class StreamExecIncrementalGroupAggregate( val aggFunction = new MiniBatchIncrementalGroupAggFunction( partialAggsHandler, finalAggsHandler, - finalKeySelector) + finalKeySelector, + config.getMinIdleStateRetentionTime) val operator = new KeyedMapBundleOperator( aggFunction, diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala index 55bdae1..462d5a1 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala @@ -76,8 +76,8 @@ class GroupAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase( val expectedOutput = new ConcurrentLinkedQueue[Object]() - // register cleanup timer with 3001 - testHarness.setProcessingTime(1) + // set TtlTimeProvider with 1 + testHarness.setStateTtlProcessingTime(1) // insertion testHarness.processElement(binaryRecord(INSERT,"aaa", 1L: JLong)) @@ -101,8 +101,8 @@ class GroupAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase( testHarness.processElement(binaryRecord(INSERT, "ccc", 3L: JLong)) expectedOutput.add(binaryRecord(INSERT, "ccc", 3L: JLong)) - // trigger cleanup timer and register cleanup timer with 6002 - testHarness.setProcessingTime(3002) + // set TtlTimeProvider with 3002 to trigger expired state cleanup + testHarness.setStateTtlProcessingTime(3002) // retract after clean up testHarness.processElement(binaryRecord(UPDATE_BEFORE, "ccc", 3L: JLong)) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala index 64f4b7a..b6add51 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala @@ -72,8 +72,8 @@ class 
TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase( testHarness.open() val expectedOutput = new ConcurrentLinkedQueue[Object]() - // register cleanup timer with 3001 - testHarness.setProcessingTime(1) + // set TtlTimeProvider with 1 + testHarness.setStateTtlProcessingTime(1) // input with two columns: key and value testHarness.processElement(insertRecord(1: JInt, 1: JInt)) @@ -104,8 +104,8 @@ class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase( testHarness.processElement(insertRecord(2: JInt, 2: JInt)) expectedOutput.add(insertRecord(2: JInt, 2: JInt, 2: JInt)) - // trigger cleanup timer - testHarness.setProcessingTime(3002) + //set TtlTimeProvider with 3002 to trigger expired state cleanup + testHarness.setStateTtlProcessingTime(3002) testHarness.processElement(insertRecord(1: JInt, 2: JInt)) expectedOutput.add(insertRecord(1: JInt, 2: JInt, 2: JInt)) @@ -136,8 +136,8 @@ class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase( testHarness.open() val expectedOutput = new ConcurrentLinkedQueue[Object]() - // register cleanup timer with 3001 - testHarness.setProcessingTime(1) + // set TtlTimeProvider with 1 + testHarness.setStateTtlProcessingTime(1) // input with two columns: key and value testHarness.processElement(insertRecord(1: JInt)) diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/PerKeyStateDataViewStore.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/PerKeyStateDataViewStore.java index 7ae63fd..5b592b4 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/PerKeyStateDataViewStore.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/PerKeyStateDataViewStore.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.state.ListState; import 
org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; @@ -38,9 +39,15 @@ public final class PerKeyStateDataViewStore implements StateDataViewStore { private static final String NULL_STATE_POSTFIX = "_null_state"; private final RuntimeContext ctx; + private final StateTtlConfig stateTtlConfig; public PerKeyStateDataViewStore(RuntimeContext ctx) { + this(ctx, StateTtlConfig.DISABLED); + } + + public PerKeyStateDataViewStore(RuntimeContext ctx, StateTtlConfig stateTtlConfig) { this.ctx = ctx; + this.stateTtlConfig = stateTtlConfig; } @Override @@ -54,6 +61,9 @@ public final class PerKeyStateDataViewStore implements StateDataViewStore { keySerializer, valueSerializer); + if (stateTtlConfig.isEnabled()) { + mapStateDescriptor.enableTimeToLive(stateTtlConfig); + } final MapState<EK, EV> mapState = ctx.getMapState(mapStateDescriptor); if (supportNullKey) { @@ -75,6 +85,9 @@ public final class PerKeyStateDataViewStore implements StateDataViewStore { stateName, elementSerializer); + if (stateTtlConfig.isEnabled()) { + listStateDescriptor.enableTimeToLive(stateTtlConfig); + } final ListState<EE> listState = ctx.getListState(listStateDescriptor); return new StateListView.KeyedStateListView<>(listState); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java index 9ad25a4..782fa06 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java +++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java @@ -25,6 +25,7 @@ import org.apache.flink.table.api.dataview.ListView; import java.util.ArrayList; import java.util.Collections; +import java.util.Iterator; import java.util.List; /** @@ -76,7 +77,12 @@ public abstract class StateListView<N, EE> extends ListView<EE> implements State @Override public boolean remove(EE value) throws Exception { - List<EE> list = (List<EE>) getListState().get(); + Iterator<EE> iterator = getListState().get().iterator(); + List<EE> list = new ArrayList<>(); + while (iterator.hasNext()) { + EE it = iterator.next(); + list.add(it); + } boolean success = list.remove(value); if (success) { getListState().update(list); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java index 633b4bb..247ae47 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java @@ -18,13 +18,16 @@ package org.apache.flink.table.runtime.operators.aggregate; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; +import org.apache.flink.table.data.JoinedRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.utils.JoinedRowData; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore; 
-import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; import org.apache.flink.table.runtime.generated.AggsHandleFunction; import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser; @@ -36,11 +39,12 @@ import org.apache.flink.util.Collector; import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg; import static org.apache.flink.table.data.util.RowDataUtil.isRetractMsg; +import static org.apache.flink.table.runtime.util.StateTtlConfigUtil.createTtlConfig; /** * Aggregate Function used for the groupby (without window) aggregate. */ -public class GroupAggFunction extends KeyedProcessFunctionWithCleanupState<RowData, RowData, RowData> { +public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, RowData> { private static final long serialVersionUID = -4767158666069797704L; @@ -70,6 +74,11 @@ public class GroupAggFunction extends KeyedProcessFunctionWithCleanupState<RowDa private final boolean generateUpdateBefore; /** + * State idle retention time which unit is MILLISECONDS. + */ + private final long stateRetentionTime; + + /** * Reused output row. */ private transient JoinedRowData resultRow = null; @@ -86,8 +95,6 @@ public class GroupAggFunction extends KeyedProcessFunctionWithCleanupState<RowDa /** * Creates a {@link GroupAggFunction}. * - * @param minRetentionTime minimal state idle retention time. - * @param maxRetentionTime maximal state idle retention time. * @param genAggsHandler The code generated function used to handle aggregates. * @param genRecordEqualiser The code generated equaliser used to equal RowData. * @param accTypes The accumulator types. @@ -95,49 +102,46 @@ public class GroupAggFunction extends KeyedProcessFunctionWithCleanupState<RowDa * -1 when the input doesn't contain COUNT(*), i.e. doesn't contain retraction messages. 
* We make sure there is a COUNT(*) if input stream contains retraction. * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages. + * @param stateRetentionTime state idle retention time which unit is MILLISECONDS. */ public GroupAggFunction( - long minRetentionTime, - long maxRetentionTime, GeneratedAggsHandleFunction genAggsHandler, GeneratedRecordEqualiser genRecordEqualiser, LogicalType[] accTypes, int indexOfCountStar, - boolean generateUpdateBefore) { - super(minRetentionTime, maxRetentionTime); + boolean generateUpdateBefore, + long stateRetentionTime) { this.genAggsHandler = genAggsHandler; this.genRecordEqualiser = genRecordEqualiser; this.accTypes = accTypes; this.recordCounter = RecordCounter.of(indexOfCountStar); this.generateUpdateBefore = generateUpdateBefore; + this.stateRetentionTime = stateRetentionTime; } @Override public void open(Configuration parameters) throws Exception { super.open(parameters); // instantiate function + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader()); - function.open(new PerKeyStateDataViewStore(getRuntimeContext())); + function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig)); // instantiate equaliser equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes); ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo); + if (ttlConfig.isEnabled()){ + accDesc.enableTimeToLive(ttlConfig); + } accState = getRuntimeContext().getState(accDesc); - initCleanupTimeState("GroupAggregateCleanupTime"); - resultRow = new JoinedRowData(); } @Override public void processElement(RowData input, Context ctx, Collector<RowData> out) throws Exception { - long currentTime = ctx.timerService().currentProcessingTime(); - // register state-cleanup timer - 
registerProcessingCleanupTimer(ctx, currentTime); - RowData currentKey = ctx.getCurrentKey(); - boolean firstRow; RowData accumulators = accState.value(); if (null == accumulators) { @@ -180,7 +184,7 @@ public class GroupAggFunction extends KeyedProcessFunctionWithCleanupState<RowDa // if this was not the first row and we have to emit retractions if (!firstRow) { - if (!stateCleaningEnabled && equaliser.equals(prevAggValue, newAggValue)) { + if (stateRetentionTime <= 0 && equaliser.equals(prevAggValue, newAggValue)) { // newRow is the same as before and state cleaning is not enabled. // We do not emit retraction and acc message. // If state cleaning is enabled, we have to emit messages to prevent too early @@ -220,14 +224,6 @@ public class GroupAggFunction extends KeyedProcessFunctionWithCleanupState<RowDa } @Override - public void onTimer(long timestamp, OnTimerContext ctx, Collector<RowData> out) throws Exception { - if (stateCleaningEnabled) { - cleanupState(accState); - function.cleanup(); - } - } - - @Override public void close() throws Exception { if (function != null) { function.close(); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java index bfb6170..ccbbbc2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java @@ -18,12 +18,13 @@ package org.apache.flink.table.runtime.operators.aggregate; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.configuration.Configuration; 
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore; -import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState; import org.apache.flink.table.runtime.generated.GeneratedTableAggsHandleFunction; import org.apache.flink.table.runtime.generated.TableAggsHandleFunction; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; @@ -31,11 +32,12 @@ import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.util.Collector; import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg; +import static org.apache.flink.table.runtime.util.StateTtlConfigUtil.createTtlConfig; /** * Aggregate Function used for the groupby (without window) table aggregate. */ -public class GroupTableAggFunction extends KeyedProcessFunctionWithCleanupState<RowData, RowData, RowData> { +public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData, RowData> { private static final long serialVersionUID = 1L; @@ -59,6 +61,11 @@ public class GroupTableAggFunction extends KeyedProcessFunctionWithCleanupState< */ private final boolean generateUpdateBefore; + /** + * State idle retention time which unit is MILLISECONDS. + */ + private final long stateRetentionTime; + // function used to handle all table aggregates private transient TableAggsHandleFunction function = null; @@ -68,49 +75,45 @@ public class GroupTableAggFunction extends KeyedProcessFunctionWithCleanupState< /** * Creates a {@link GroupTableAggFunction}. * - * @param minRetentionTime minimal state idle retention time. - * @param maxRetentionTime maximal state idle retention time. * @param genAggsHandler The code generated function used to handle table aggregates. * @param accTypes The accumulator types. * @param indexOfCountStar The index of COUNT(*) in the aggregates. * -1 when the input doesn't contain COUNT(*), i.e. 
doesn't contain retraction messages. * We make sure there is a COUNT(*) if input stream contains retraction. * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages. + * @param stateRetentionTime state idle retention time which unit is MILLISECONDS. */ public GroupTableAggFunction( - long minRetentionTime, - long maxRetentionTime, GeneratedTableAggsHandleFunction genAggsHandler, LogicalType[] accTypes, int indexOfCountStar, - boolean generateUpdateBefore) { - super(minRetentionTime, maxRetentionTime); + boolean generateUpdateBefore, + long stateRetentionTime) { this.genAggsHandler = genAggsHandler; this.accTypes = accTypes; this.recordCounter = RecordCounter.of(indexOfCountStar); this.generateUpdateBefore = generateUpdateBefore; + this.stateRetentionTime = stateRetentionTime; } @Override public void open(Configuration parameters) throws Exception { super.open(parameters); // instantiate function + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader()); - function.open(new PerKeyStateDataViewStore(getRuntimeContext())); + function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig)); InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes); ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo); + if (ttlConfig.isEnabled()){ + accDesc.enableTimeToLive(ttlConfig); + } accState = getRuntimeContext().getState(accDesc); - - initCleanupTimeState("GroupTableAggregateCleanupTime"); } @Override public void processElement(RowData input, Context ctx, Collector<RowData> out) throws Exception { - long currentTime = ctx.timerService().currentProcessingTime(); - // register state-cleanup timer - registerProcessingCleanupTimer(ctx, currentTime); - RowData currentKey = ctx.getCurrentKey(); boolean firstRow; @@ -155,14 +158,6 @@ public class GroupTableAggFunction extends 
KeyedProcessFunctionWithCleanupState< } @Override - public void onTimer(long timestamp, OnTimerContext ctx, Collector<RowData> out) throws Exception { - if (stateCleaningEnabled) { - cleanupState(accState); - function.cleanup(); - } - } - - @Override public void close() throws Exception { if (function != null) { function.close(); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGlobalGroupAggFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGlobalGroupAggFunction.java index 55ed320..97dccda 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGlobalGroupAggFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGlobalGroupAggFunction.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.operators.aggregate; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.table.data.RowData; @@ -38,6 +39,8 @@ import javax.annotation.Nullable; import java.util.Map; +import static org.apache.flink.table.runtime.util.StateTtlConfigUtil.createTtlConfig; + /** * Aggregate Function used for the global groupby (without window) aggregate in miniBatch mode. */ @@ -77,6 +80,11 @@ public class MiniBatchGlobalGroupAggFunction extends MapBundleFunction<RowData, private final boolean generateUpdateBefore; /** + * State idle retention time which unit is MILLISECONDS. + */ + private final long stateRetentionTime; + + /** * Reused output row. 
*/ private transient JoinedRowData resultRow = new JoinedRowData(); @@ -104,6 +112,7 @@ public class MiniBatchGlobalGroupAggFunction extends MapBundleFunction<RowData, * -1 when the input doesn't contain COUNT(*), i.e. doesn't contain UPDATE_BEFORE or DELETE messages. * We make sure there is a COUNT(*) if input stream contains UPDATE_BEFORE or DELETE messages. * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages. + * @param stateRetentionTime state idle retention time which unit is MILLISECONDS. */ public MiniBatchGlobalGroupAggFunction( GeneratedAggsHandleFunction genLocalAggsHandler, @@ -111,28 +120,34 @@ public class MiniBatchGlobalGroupAggFunction extends MapBundleFunction<RowData, GeneratedRecordEqualiser genRecordEqualiser, LogicalType[] accTypes, int indexOfCountStar, - boolean generateUpdateBefore) { + boolean generateUpdateBefore, + long stateRetentionTime) { this.genLocalAggsHandler = genLocalAggsHandler; this.genGlobalAggsHandler = genGlobalAggsHandler; this.genRecordEqualiser = genRecordEqualiser; this.accTypes = accTypes; this.recordCounter = RecordCounter.of(indexOfCountStar); this.generateUpdateBefore = generateUpdateBefore; + this.stateRetentionTime = stateRetentionTime; } @Override public void open(ExecutionContext ctx) throws Exception { super.open(ctx); + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); localAgg = genLocalAggsHandler.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); localAgg.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext())); globalAgg = genGlobalAggsHandler.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); - globalAgg.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext())); + globalAgg.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext(), ttlConfig)); equaliser = genRecordEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes); 
ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo); + if (ttlConfig.isEnabled()){ + accDesc.enableTimeToLive(ttlConfig); + } accState = ctx.getRuntimeContext().getState(accDesc); resultRow = new JoinedRowData(); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunction.java index 246c065..79bdd4a 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunction.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.operators.aggregate; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; @@ -44,6 +45,7 @@ import java.util.List; import java.util.Map; import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg; +import static org.apache.flink.table.runtime.util.StateTtlConfigUtil.createTtlConfig; /** * Aggregate Function used for the groupby (without window) aggregate in miniBatch mode. @@ -85,6 +87,11 @@ public class MiniBatchGroupAggFunction extends MapBundleFunction<RowData, List<R private final boolean generateUpdateBefore; /** + * State idle retention time which unit is MILLISECONDS. + */ + private final long stateRetentionTime; + + /** * Reused output row. */ private transient JoinedRowData resultRow = new JoinedRowData(); @@ -112,6 +119,7 @@ public class MiniBatchGroupAggFunction extends MapBundleFunction<RowData, List<R * -1 when the input doesn't contain COUNT(*), i.e. 
doesn't contain retraction messages. * We make sure there is a COUNT(*) if input stream contains retraction. * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages. + * @param stateRetentionTime state idle retention time which unit is MILLISECONDS. */ public MiniBatchGroupAggFunction( GeneratedAggsHandleFunction genAggsHandler, @@ -119,26 +127,32 @@ public class MiniBatchGroupAggFunction extends MapBundleFunction<RowData, List<R LogicalType[] accTypes, RowType inputType, int indexOfCountStar, - boolean generateUpdateBefore) { + boolean generateUpdateBefore, + long stateRetentionTime) { this.genAggsHandler = genAggsHandler; this.genRecordEqualiser = genRecordEqualiser; this.recordCounter = RecordCounter.of(indexOfCountStar); this.accTypes = accTypes; this.inputType = inputType; this.generateUpdateBefore = generateUpdateBefore; + this.stateRetentionTime = stateRetentionTime; } @Override public void open(ExecutionContext ctx) throws Exception { super.open(ctx); // instantiate function + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); function = genAggsHandler.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); - function.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext())); + function.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext(), ttlConfig)); // instantiate equaliser equaliser = genRecordEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes); ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo); + if (ttlConfig.isEnabled()){ + accDesc.enableTimeToLive(ttlConfig); + } accState = ctx.getRuntimeContext().getState(accDesc); inputRowSerializer = InternalSerializers.create(inputType); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchIncrementalGroupAggFunction.java 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchIncrementalGroupAggFunction.java index 5256f92..fbb3d89 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchIncrementalGroupAggFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchIncrementalGroupAggFunction.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.operators.aggregate; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.utils.JoinedRowData; @@ -33,6 +34,8 @@ import javax.annotation.Nullable; import java.util.HashMap; import java.util.Map; +import static org.apache.flink.table.runtime.util.StateTtlConfigUtil.createTtlConfig; + /** * Aggregate Function used for the incremental groupby (without window) aggregate in miniBatch mode. */ @@ -56,6 +59,11 @@ public class MiniBatchIncrementalGroupAggFunction extends MapBundleFunction<RowD private final KeySelector<RowData, RowData> finalKeySelector; /** + * State idle retention time which unit is MILLISECONDS. + */ + private final long stateRetentionTime; + + /** * Reused output row. 
*/ private transient JoinedRowData resultRow = new JoinedRowData(); @@ -69,21 +77,24 @@ public class MiniBatchIncrementalGroupAggFunction extends MapBundleFunction<RowD public MiniBatchIncrementalGroupAggFunction( GeneratedAggsHandleFunction genPartialAggsHandler, GeneratedAggsHandleFunction genFinalAggsHandler, - KeySelector<RowData, RowData> finalKeySelector) { + KeySelector<RowData, RowData> finalKeySelector, + long stateRetentionTime) { this.genPartialAggsHandler = genPartialAggsHandler; this.genFinalAggsHandler = genFinalAggsHandler; this.finalKeySelector = finalKeySelector; + this.stateRetentionTime = stateRetentionTime; } @Override public void open(ExecutionContext ctx) throws Exception { super.open(ctx); ClassLoader classLoader = ctx.getRuntimeContext().getUserCodeClassLoader(); + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); partialAgg = genPartialAggsHandler.newInstance(classLoader); partialAgg.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext())); finalAgg = genFinalAggsHandler.newInstance(classLoader); - finalAgg.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext())); + finalAgg.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext(), ttlConfig)); resultRow = new JoinedRowData(); } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/AbstractStreamingJoinOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/AbstractStreamingJoinOperator.java index b271746..28a8947 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/AbstractStreamingJoinOperator.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/AbstractStreamingJoinOperator.java @@ -75,7 +75,7 @@ public abstract class AbstractStreamingJoinOperator extends AbstractStreamOperat */ private final boolean 
filterAllNulls; - protected final long minRetentionTime; + protected final long stateRetentionTime; protected transient JoinConditionWithNullFilters joinCondition; protected transient TimestampedCollector<RowData> collector; @@ -87,13 +87,13 @@ public abstract class AbstractStreamingJoinOperator extends AbstractStreamOperat JoinInputSideSpec leftInputSideSpec, JoinInputSideSpec rightInputSideSpec, boolean[] filterNullKeys, - long minRetentionTime) { + long stateRetentionTime) { this.leftType = leftType; this.rightType = rightType; this.generatedJoinCondition = generatedJoinCondition; this.leftInputSideSpec = leftInputSideSpec; this.rightInputSideSpec = rightInputSideSpec; - this.minRetentionTime = minRetentionTime; + this.stateRetentionTime = stateRetentionTime; this.nullFilterKeys = NullAwareJoinHelper.getNullFilterKeys(filterNullKeys); this.nullSafe = nullFilterKeys.length == 0; this.filterAllNulls = nullFilterKeys.length == filterNullKeys.length; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingJoinOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingJoinOperator.java index fe74c3b..5853848 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingJoinOperator.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingJoinOperator.java @@ -62,8 +62,8 @@ public class StreamingJoinOperator extends AbstractStreamingJoinOperator { boolean leftIsOuter, boolean rightIsOuter, boolean[] filterNullKeys, - long minRetentionTime) { - super(leftType, rightType, generatedJoinCondition, leftInputSideSpec, rightInputSideSpec, filterNullKeys, minRetentionTime); + long stateRetentionTime) { + super(leftType, rightType, generatedJoinCondition, leftInputSideSpec, rightInputSideSpec, 
filterNullKeys, stateRetentionTime); this.leftIsOuter = leftIsOuter; this.rightIsOuter = rightIsOuter; } @@ -83,14 +83,14 @@ public class StreamingJoinOperator extends AbstractStreamingJoinOperator { "left-records", leftInputSideSpec, leftType, - minRetentionTime); + stateRetentionTime); } else { this.leftRecordStateView = JoinRecordStateViews.create( getRuntimeContext(), "left-records", leftInputSideSpec, leftType, - minRetentionTime); + stateRetentionTime); } if (rightIsOuter) { @@ -99,14 +99,14 @@ public class StreamingJoinOperator extends AbstractStreamingJoinOperator { "right-records", rightInputSideSpec, rightType, - minRetentionTime); + stateRetentionTime); } else { this.rightRecordStateView = JoinRecordStateViews.create( getRuntimeContext(), "right-records", rightInputSideSpec, rightType, - minRetentionTime); + stateRetentionTime); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingSemiAntiJoinOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingSemiAntiJoinOperator.java index 69f5607..752ee65 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingSemiAntiJoinOperator.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingSemiAntiJoinOperator.java @@ -53,8 +53,8 @@ public class StreamingSemiAntiJoinOperator extends AbstractStreamingJoinOperator JoinInputSideSpec leftInputSideSpec, JoinInputSideSpec rightInputSideSpec, boolean[] filterNullKeys, - long minRetentionTime) { - super(leftType, rightType, generatedJoinCondition, leftInputSideSpec, rightInputSideSpec, filterNullKeys, minRetentionTime); + long stateRetentionTime) { + super(leftType, rightType, generatedJoinCondition, leftInputSideSpec, rightInputSideSpec, filterNullKeys, stateRetentionTime); 
this.isAntiJoin = isAntiJoin; } @@ -67,14 +67,14 @@ public class StreamingSemiAntiJoinOperator extends AbstractStreamingJoinOperator LEFT_RECORDS_STATE_NAME, leftInputSideSpec, leftType, - minRetentionTime); + stateRetentionTime); this.rightRecordStateView = JoinRecordStateViews.create( getRuntimeContext(), RIGHT_RECORDS_STATE_NAME, rightInputSideSpec, rightType, - minRetentionTime); + stateRetentionTime); } /** diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionTest.java new file mode 100644 index 0000000..5910752 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.aggregate; + +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.data.RowData; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.insertRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.updateAfterRecord; + +/** + * Tests for {@link GroupAggFunction}. + */ +public class GroupAggFunctionTest extends GroupAggFunctionTestBase { + + private OneInputStreamOperatorTestHarness<RowData, RowData> createTestHarness( + GroupAggFunction aggFunction) throws Exception { + KeyedProcessOperator<RowData, RowData, RowData> operator = new KeyedProcessOperator<>(aggFunction); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, keyType); + } + + private GroupAggFunction createFunction(boolean generateUpdateBefore) { + return new GroupAggFunction( + function, + equaliser, + accTypes, + -1, + generateUpdateBefore, + minTime.toMilliseconds()); + } + + @Test + public void testGroupAggWithStateTtl() throws Exception { + GroupAggFunction groupAggFunction = createFunction(false); + OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = createTestHarness(groupAggFunction); + testHarness.open(); + testHarness.setup(); + + testHarness.processElement(insertRecord("key1", 1, 20L)); + testHarness.processElement(insertRecord("key1", 2, 0L)); + testHarness.processElement(insertRecord("key1", 3, 999L)); + + testHarness.processElement(insertRecord("key2", 1, 3999L)); + testHarness.processElement(insertRecord("key2", 2, 3000L)); + testHarness.processElement(insertRecord("key2", 3, 1000L)); + + //trigger expired state cleanup + testHarness.setStateTtlProcessingTime(20); + 
testHarness.processElement(insertRecord("key1", 4, 1020L)); + testHarness.processElement(insertRecord("key1", 5, 1290L)); + testHarness.processElement(insertRecord("key1", 6, 1290L)); + + testHarness.processElement(insertRecord("key2", 4, 4999L)); + testHarness.processElement(insertRecord("key2", 5, 6000L)); + testHarness.processElement(insertRecord("key2", 6, 2000L)); + + List<Object> expectedOutput = new ArrayList<>(); + expectedOutput.add(insertRecord("key1", 1L, 1L)); + expectedOutput.add(updateAfterRecord("key1", 3L, 2L)); + expectedOutput.add(updateAfterRecord("key1", 6L, 3L)); + expectedOutput.add(insertRecord("key2", 1L, 1L)); + expectedOutput.add(updateAfterRecord("key2", 3L, 2L)); + expectedOutput.add(updateAfterRecord("key2", 6L, 3L)); + // result doesn't contain expired records with the same key + expectedOutput.add(insertRecord("key1", 4L, 1L)); + expectedOutput.add(updateAfterRecord("key1", 9L, 2L)); + expectedOutput.add(updateAfterRecord("key1", 15L, 3L)); + expectedOutput.add(insertRecord("key2", 4L, 1L)); + expectedOutput.add(updateAfterRecord("key2", 9L, 2L)); + expectedOutput.add(updateAfterRecord("key2", 15L, 3L)); + + assertor.assertOutputEqualsSorted("output wrong.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionTestBase.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionTestBase.java new file mode 100644 index 0000000..03aafe8 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionTestBase.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.dataview.StateDataViewStore; +import org.apache.flink.table.runtime.generated.AggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.generated.RecordEqualiser; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.runtime.util.BinaryRowDataKeySelector; +import org.apache.flink.table.runtime.util.GenericRowRecordSortComparator; +import org.apache.flink.table.runtime.util.RowDataHarnessAssertor; +import org.apache.flink.table.runtime.util.RowDataRecordEqualiser; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.VarCharType; + +/** + * Base class of tests for all kinds of GroupAgg. 
+ */ +abstract class GroupAggFunctionTestBase { + + Time minTime = Time.milliseconds(10); + + LogicalType[] inputFieldTypes = new LogicalType[] { + new VarCharType(VarCharType.MAX_LENGTH), + new IntType(), + new BigIntType() }; + + InternalTypeInfo<RowData> outputType = InternalTypeInfo.ofFields( + new VarCharType(VarCharType.MAX_LENGTH), + new BigIntType(), + new BigIntType()); + + LogicalType[] accTypes = new LogicalType[] { new BigIntType(), new BigIntType() }; + BinaryRowDataKeySelector keySelector = new BinaryRowDataKeySelector(new int[]{0}, inputFieldTypes); + TypeInformation<RowData> keyType = keySelector.getProducedType(); + GeneratedRecordEqualiser equaliser = new GeneratedRecordEqualiser("", "", new Object[0]) { + + private static final long serialVersionUID = 1532460173848746788L; + + @Override + public RecordEqualiser newInstance(ClassLoader classLoader) { + return new RowDataRecordEqualiser(); + } + }; + + GeneratedAggsHandleFunction function = + new GeneratedAggsHandleFunction("Function", "", new Object[0]) { + @Override + public AggsHandleFunction newInstance(ClassLoader classLoader) { + return new SumAndCountAgg(); + } + }; + + RowDataHarnessAssertor assertor = new RowDataHarnessAssertor( + outputType.toRowFieldTypes(), + new GenericRowRecordSortComparator(0, new VarCharType(VarCharType.MAX_LENGTH))); + + static final class SumAndCountAgg implements AggsHandleFunction { + private long sum; + private boolean sumIsNull; + private long count; + private boolean countIsNull; + + @Override + public void open(StateDataViewStore store) throws Exception { + } + + @Override + public void setAccumulators(RowData acc) throws Exception { + sumIsNull = acc.isNullAt(0); + if (!sumIsNull) { + sum = acc.getLong(0); + } + + countIsNull = acc.isNullAt(1); + if (!countIsNull) { + count = acc.getLong(1); + } + } + + @Override + public void accumulate(RowData inputRow) throws Exception { + boolean inputIsNull = inputRow.isNullAt(1); + if (!inputIsNull) { + sum += 
inputRow.getInt(1); + count += 1; + } + } + + @Override + public void retract(RowData inputRow) throws Exception { + boolean inputIsNull = inputRow.isNullAt(1); + if (!inputIsNull) { + sum -= inputRow.getInt(1); + count -= 1; + } + } + + @Override + public void merge(RowData otherAcc) throws Exception { + boolean sumIsNullOther = otherAcc.isNullAt(0); + if (!sumIsNullOther) { + sum += otherAcc.getLong(0); + } + + boolean countIsNullOther = otherAcc.isNullAt(1); + if (!countIsNullOther) { + count += otherAcc.getLong(1); + } + } + + @Override + public void resetAccumulators() throws Exception { + sum = 0L; + count = 0L; + } + + @Override + public RowData getAccumulators() throws Exception { + GenericRowData acc = new GenericRowData(2); + if (!sumIsNull) { + acc.setField(0, sum); + } + + if (!countIsNull) { + acc.setField(1, count); + } + + return acc; + } + + @Override + public RowData createAccumulators() throws Exception { + GenericRowData acc = new GenericRowData(2); + acc.setField(0, 0L); + acc.setField(1, 0L); + return acc; + } + + @Override + public RowData getValue() throws Exception { + return getAccumulators(); + } + + @Override + public void cleanup() throws Exception { + + } + + @Override + public void close() throws Exception { + + } + } + +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunctionTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunctionTest.java new file mode 100644 index 0000000..c1d3980 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/operators/aggregate/MiniBatchGroupAggFunctionTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.operators.bundle.trigger.CountBundleTrigger; +import org.apache.flink.table.types.logical.RowType; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StreamRecordUtils.insertRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.updateAfterRecord; + +/** + * Tests for {@link MiniBatchGroupAggFunction}. 
+ */ +public class MiniBatchGroupAggFunctionTest extends GroupAggFunctionTestBase { + + private OneInputStreamOperatorTestHarness<RowData, RowData> createTestHarness( + MiniBatchGroupAggFunction aggFunction) throws Exception { + CountBundleTrigger<Tuple2<String, String>> trigger = new CountBundleTrigger<>(3); + KeyedMapBundleOperator operator = new KeyedMapBundleOperator(aggFunction, trigger); + return new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, keyType); + } + + private MiniBatchGroupAggFunction createFunction(boolean generateUpdateBefore) throws Exception { + return new MiniBatchGroupAggFunction( + function, + equaliser, + accTypes, + RowType.of(inputFieldTypes), + -1, + false, + minTime.toMilliseconds()); + } + + @Test + public void testMiniBatchGroupAggWithStateTtl() throws Exception { + + MiniBatchGroupAggFunction function = createFunction(false); + OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = createTestHarness(function); + testHarness.open(); + testHarness.setup(); + + testHarness.processElement(insertRecord("key1", 1, 20L)); + testHarness.processElement(insertRecord("key2", 1, 3000L)); + testHarness.processElement(insertRecord("key1", 3, 999L)); + + testHarness.processElement(insertRecord("key1", 2, 500L)); + testHarness.processElement(insertRecord("key2", 2, 3999L)); + testHarness.processElement(insertRecord("key2", 3, 1000L)); + + //trigger expired state cleanup + testHarness.setStateTtlProcessingTime(20); + testHarness.processElement(insertRecord("key1", 4, 1020L)); + testHarness.processElement(insertRecord("key1", 5, 1290L)); + testHarness.processElement(insertRecord("key1", 6, 1290L)); + + testHarness.processElement(insertRecord("key2", 4, 4999L)); + testHarness.processElement(insertRecord("key2", 5, 6000L)); + testHarness.processElement(insertRecord("key2", 6, 2000L)); + + List<Object> expectedOutput = new ArrayList<>(); + expectedOutput.add(insertRecord("key1", 4L, 2L)); + 
expectedOutput.add(insertRecord("key2", 1L, 1L)); + expectedOutput.add(updateAfterRecord("key1", 6L, 3L)); + expectedOutput.add(updateAfterRecord("key2", 6L, 3L)); + // result doesn't contain expired records with the same key + expectedOutput.add(insertRecord("key1", 15L, 3L)); + expectedOutput.add(insertRecord("key2", 15L, 3L)); + + assertor.assertOutputEqualsSorted("output wrong.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } +}
