This is an automated email from the ASF dual-hosted git repository.
xintongsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 74b5e08e Fix prune state corner cases around checkpoints (#603)
74b5e08e is described below
commit 74b5e08e1c4dc3809507a83b8ec8787188c41b5d
Author: Joey Tong <[email protected]>
AuthorDate: Thu May 7 19:34:47 2026 +0800
Fix prune state corner cases around checkpoints (#603)
---
.../runtime/operator/ActionExecutionOperator.java | 26 ++--
.../operator/ActionExecutionOperatorTest.java | 135 ++++++++++++++++++++-
2 files changed, 151 insertions(+), 10 deletions(-)
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 4765d37b..926a5283 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -133,6 +133,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker";
private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME =
"messageSequenceNumber";
+ private static final String LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME =
+ "lastCompletedSequenceNumber";
private static final String PENDING_INPUT_EVENT_STATE_NAME =
"pendingInputEvents";
private final AgentPlan agentPlan;
@@ -202,6 +204,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
private transient ActionStateStore actionStateStore;
private transient ValueState<Long> sequenceNumberKState;
+ private transient ValueState<Long> lastCompletedSequenceNumberKState;
private transient ListState<Object> recoveryMarkerOpState;
private transient Map<Long, Map<Object, Long>> checkpointIdToSeqNums;
@@ -301,6 +304,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
.getState(
new ValueStateDescriptor<>(
MESSAGE_SEQUENCE_NUMBER_STATE_NAME,
Long.class));
+ lastCompletedSequenceNumberKState =
+ getRuntimeContext()
+ .getState(
+ new ValueStateDescriptor<>(
+
LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class));
// init agent processing related state
actionTasksKState =
@@ -591,11 +599,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
if (currentInputEventFinished) {
// Clean up sensory memory when a single run finished.
actionTask.getRunnerContext().clearSensoryMemory();
+ lastCompletedSequenceNumberKState.update(sequenceNumber);
// Once all sub-events and actions related to the current
InputEvent are completed,
// we can proceed to process the next InputEvent.
int removedCount =
removeFromListState(currentProcessingKeysOpState, key);
- maybePruneState(key, sequenceNumber);
checkState(
removedCount == 1,
"Current processing key count for key "
@@ -859,8 +867,14 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
.applyToAllKeys(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
- new
ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class),
- (key, state) -> keyToSeqNum.put(key, state.value()));
+ new ValueStateDescriptor<>(
+ LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME,
Long.class),
+ (key, state) -> {
+ Long completedSequenceNumber = state.value();
+ if (completedSequenceNumber != null) {
+ keyToSeqNum.put(key, completedSequenceNumber);
+ }
+ });
checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum);
super.snapshotState(context);
@@ -1139,12 +1153,6 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
}
}
- private void maybePruneState(Object key, long sequenceNum) throws
Exception {
- if (actionStateStore != null) {
- actionStateStore.pruneState(key, sequenceNum);
- }
- }
-
private void processEligibleWatermarks() throws Exception {
Watermark mark = keySegmentQueue.popOldestWatermark();
while (mark != null) {
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index c02b3af0..2fe2d486 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -309,6 +309,62 @@ public class ActionExecutionOperatorTest {
}
}
+ @Test
+ void testDoesNotPruneBeforeCheckpointComplete() throws Exception {
+ AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
+ RecordingActionStateStore actionStateStore = new
RecordingActionStateStore();
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(
+ agentPlanWithStateStore, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ testHarness.processElement(new StreamRecord<>(5L));
+ operator.waitInFlightEventsFinished();
+ assertThat(actionStateStore.getPrunedSeqNums()).isEmpty();
+
+ testHarness.snapshot(1L, 1L);
+ assertThat(actionStateStore.getPrunedSeqNums()).isEmpty();
+ testHarness.notifyOfCompletedCheckpoint(1L);
+
+
assertThat(actionStateStore.getPrunedSeqNums()).containsExactly(0L);
+ }
+ }
+
+ @Test
+ void testDoesNotPruneSeqsInFlight() throws Exception {
+ AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
+ RecordingActionStateStore actionStateStore = new
RecordingActionStateStore();
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(
+ agentPlanWithStateStore, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ testHarness.processElement(new StreamRecord<>(5L));
+ operator.waitInFlightEventsFinished();
+ actionStateStore.clearPruneCalls();
+
+ testHarness.processElement(new StreamRecord<>(5L));
+ assertThat(testHarness.getTaskMailbox().size()).isEqualTo(1);
+
+ testHarness.snapshot(1L, 1L);
+ testHarness.notifyOfCompletedCheckpoint(1L);
+
+
assertThat(actionStateStore.getPrunedSeqNums()).containsExactly(0L);
+ }
+ }
+
@Test
void testEventLogBaseDirFromAgentConfig() throws Exception {
String baseLogDir = "/tmp/flink-agents-test";
@@ -473,7 +529,7 @@ public class ActionExecutionOperatorTest {
}
@Test
- void testActionStateStoreCleanupAfterOutputEvent() throws Exception {
+ void testActionStateStoreCleanupAfterCheckpointComplete() throws Exception
{
AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
@@ -508,10 +564,66 @@ public class ActionExecutionOperatorTest {
actionStateStoreField.setAccessible(true);
InMemoryActionStateStore actionStateStore =
(InMemoryActionStateStore)
actionStateStoreField.get(operator);
+ assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+
+ testHarness.snapshot(1L, 1L);
+ testHarness.notifyOfCompletedCheckpoint(1L);
+
assertThat(actionStateStore.getKeyedActionStates()).isEmpty();
}
}
+ @Test
+ void testEarlierCheckpointReplayKeepsDurableState() throws Exception {
+ AgentPlan agentPlan = TestAgent.getDurableSyncAgentPlan();
+ InMemoryActionStateStore actionStateStore = new
InMemoryActionStateStore(true);
+ OperatorSubtaskState snapshot;
+
+ TestAgent.DURABLE_CALL_COUNTER.set(0);
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(agentPlan, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ // Simulate failure recovery from a checkpoint taken before this
input was processed.
+ snapshot = testHarness.snapshot(1L, 1L);
+
+ testHarness.processElement(new StreamRecord<>(7L));
+ operator.waitInFlightEventsFinished();
+
+ assertThat(TestAgent.DURABLE_CALL_COUNTER.get()).isEqualTo(1);
+ assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+ }
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(agentPlan, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.initializeState(snapshot);
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ // Replay the same input after restoring from the earlier
checkpoint.
+ testHarness.processElement(new StreamRecord<>(7L));
+ operator.waitInFlightEventsFinished();
+
+ List<StreamRecord<Object>> recordOutput =
+ (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+ assertThat(recordOutput).hasSize(1);
+ assertThat(recordOutput.get(0).getValue()).isEqualTo(21L);
+ assertThat(TestAgent.DURABLE_CALL_COUNTER.get())
+ .as("Durable supplier should not be re-executed during
replay")
+ .isEqualTo(1);
+ }
+ }
+
@Test
void testActionStateStoreReplayIncurNoFunctionCall() throws Exception {
AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
@@ -1886,6 +1998,27 @@ public class ActionExecutionOperatorTest {
return actionStateStore.get(key, 0L, action, event);
}
+ private static class RecordingActionStateStore extends
InMemoryActionStateStore {
+ private final List<Long> prunedSeqNums = new java.util.ArrayList<>();
+
+ private RecordingActionStateStore() {
+ super(false);
+ }
+
+ @Override
+ public void pruneState(Object key, long seqNum) {
+ prunedSeqNums.add(seqNum);
+ }
+
+ private void clearPruneCalls() {
+ prunedSeqNums.clear();
+ }
+
+ private List<Long> getPrunedSeqNums() {
+ return prunedSeqNums;
+ }
+ }
+
private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int
expectedSize)
throws Exception {
assertThat(mailbox.size()).isEqualTo(expectedSize);