This is an automated email from the ASF dual-hosted git repository.

xintongsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new 74b5e08e Fix prune state corner cases around checkpoints (#603)
74b5e08e is described below

commit 74b5e08e1c4dc3809507a83b8ec8787188c41b5d
Author: Joey Tong <[email protected]>
AuthorDate: Thu May 7 19:34:47 2026 +0800

    Fix prune state corner cases around checkpoints (#603)
---
 .../runtime/operator/ActionExecutionOperator.java  |  26 ++--
 .../operator/ActionExecutionOperatorTest.java      | 135 ++++++++++++++++++++-
 2 files changed, 151 insertions(+), 10 deletions(-)

diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 4765d37b..926a5283 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -133,6 +133,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
 
     private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker";
     private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber";
+    private static final String LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME =
+            "lastCompletedSequenceNumber";
     private static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents";
 
     private final AgentPlan agentPlan;
@@ -202,6 +204,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
 
     private transient ActionStateStore actionStateStore;
     private transient ValueState<Long> sequenceNumberKState;
+    private transient ValueState<Long> lastCompletedSequenceNumberKState;
     private transient ListState<Object> recoveryMarkerOpState;
     private transient Map<Long, Map<Object, Long>> checkpointIdToSeqNums;
 
@@ -301,6 +304,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
                         .getState(
                                 new ValueStateDescriptor<>(
                                         MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class));
+        lastCompletedSequenceNumberKState =
+                getRuntimeContext()
+                        .getState(
+                                new ValueStateDescriptor<>(
+                                        LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class));
 
         // init agent processing related state
         actionTasksKState =
@@ -591,11 +599,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
         if (currentInputEventFinished) {
             // Clean up sensory memory when a single run finished.
             actionTask.getRunnerContext().clearSensoryMemory();
+            lastCompletedSequenceNumberKState.update(sequenceNumber);
 
             // Once all sub-events and actions related to the current InputEvent are completed,
             // we can proceed to process the next InputEvent.
             int removedCount = removeFromListState(currentProcessingKeysOpState, key);
-            maybePruneState(key, sequenceNumber);
             checkState(
                     removedCount == 1,
                     "Current processing key count for key "
@@ -859,8 +867,14 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
                 .applyToAllKeys(
                         VoidNamespace.INSTANCE,
                         VoidNamespaceSerializer.INSTANCE,
-                        new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class),
-                        (key, state) -> keyToSeqNum.put(key, state.value()));
+                        new ValueStateDescriptor<>(
+                                LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class),
+                        (key, state) -> {
+                            Long completedSequenceNumber = state.value();
+                            if (completedSequenceNumber != null) {
+                                keyToSeqNum.put(key, completedSequenceNumber);
+                            }
+                        });
         checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum);
 
         super.snapshotState(context);
@@ -1139,12 +1153,6 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
         }
     }
 
-    private void maybePruneState(Object key, long sequenceNum) throws Exception {
-        if (actionStateStore != null) {
-            actionStateStore.pruneState(key, sequenceNum);
-        }
-    }
-
     private void processEligibleWatermarks() throws Exception {
         Watermark mark = keySegmentQueue.popOldestWatermark();
         while (mark != null) {
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index c02b3af0..2fe2d486 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -309,6 +309,62 @@ public class ActionExecutionOperatorTest {
         }
     }
 
+    @Test
+    void testDoesNotPruneBeforeCheckpointComplete() throws Exception {
+        AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
+        RecordingActionStateStore actionStateStore = new RecordingActionStateStore();
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(
+                                agentPlanWithStateStore, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(5L));
+            operator.waitInFlightEventsFinished();
+            assertThat(actionStateStore.getPrunedSeqNums()).isEmpty();
+
+            testHarness.snapshot(1L, 1L);
+            assertThat(actionStateStore.getPrunedSeqNums()).isEmpty();
+            testHarness.notifyOfCompletedCheckpoint(1L);
+
+            assertThat(actionStateStore.getPrunedSeqNums()).containsExactly(0L);
+        }
+    }
+
+    @Test
+    void testDoesNotPruneSeqsInFlight() throws Exception {
+        AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
+        RecordingActionStateStore actionStateStore = new RecordingActionStateStore();
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(
+                                agentPlanWithStateStore, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(5L));
+            operator.waitInFlightEventsFinished();
+            actionStateStore.clearPruneCalls();
+
+            testHarness.processElement(new StreamRecord<>(5L));
+            assertThat(testHarness.getTaskMailbox().size()).isEqualTo(1);
+
+            testHarness.snapshot(1L, 1L);
+            testHarness.notifyOfCompletedCheckpoint(1L);
+
+            assertThat(actionStateStore.getPrunedSeqNums()).containsExactly(0L);
+        }
+    }
+
     @Test
     void testEventLogBaseDirFromAgentConfig() throws Exception {
         String baseLogDir = "/tmp/flink-agents-test";
@@ -473,7 +529,7 @@ public class ActionExecutionOperatorTest {
     }
 
     @Test
-    void testActionStateStoreCleanupAfterOutputEvent() throws Exception {
+    void testActionStateStoreCleanupAfterCheckpointComplete() throws Exception {
         AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
 
         try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
@@ -508,10 +564,66 @@ public class ActionExecutionOperatorTest {
             actionStateStoreField.setAccessible(true);
             InMemoryActionStateStore actionStateStore =
                     (InMemoryActionStateStore) actionStateStoreField.get(operator);
+            assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+
+            testHarness.snapshot(1L, 1L);
+            testHarness.notifyOfCompletedCheckpoint(1L);
+
             assertThat(actionStateStore.getKeyedActionStates()).isEmpty();
         }
     }
 
+    @Test
+    void testEarlierCheckpointReplayKeepsDurableState() throws Exception {
+        AgentPlan agentPlan = TestAgent.getDurableSyncAgentPlan();
+        InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(true);
+        OperatorSubtaskState snapshot;
+
+        TestAgent.DURABLE_CALL_COUNTER.set(0);
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            // Simulate failure recovery from a checkpoint taken before this input was processed.
+            snapshot = testHarness.snapshot(1L, 1L);
+
+            testHarness.processElement(new StreamRecord<>(7L));
+            operator.waitInFlightEventsFinished();
+
+            assertThat(TestAgent.DURABLE_CALL_COUNTER.get()).isEqualTo(1);
+            assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+        }
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.initializeState(snapshot);
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            // Replay the same input after restoring from the earlier checkpoint.
+            testHarness.processElement(new StreamRecord<>(7L));
+            operator.waitInFlightEventsFinished();
+
+            List<StreamRecord<Object>> recordOutput =
+                    (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+            assertThat(recordOutput).hasSize(1);
+            assertThat(recordOutput.get(0).getValue()).isEqualTo(21L);
+            assertThat(TestAgent.DURABLE_CALL_COUNTER.get())
+                    .as("Durable supplier should not be re-executed during replay")
+                    .isEqualTo(1);
+        }
+    }
+
     @Test
     void testActionStateStoreReplayIncurNoFunctionCall() throws Exception {
         AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
@@ -1886,6 +1998,27 @@ public class ActionExecutionOperatorTest {
         return actionStateStore.get(key, 0L, action, event);
     }
 
+    private static class RecordingActionStateStore extends InMemoryActionStateStore {
+        private final List<Long> prunedSeqNums = new java.util.ArrayList<>();
+
+        private RecordingActionStateStore() {
+            super(false);
+        }
+
+        @Override
+        public void pruneState(Object key, long seqNum) {
+            prunedSeqNums.add(seqNum);
+        }
+
+        private void clearPruneCalls() {
+            prunedSeqNums.clear();
+        }
+
+        private List<Long> getPrunedSeqNums() {
+            return prunedSeqNums;
+        }
+    }
+
     private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)
             throws Exception {
         assertThat(mailbox.size()).isEqualTo(expectedSize);

Reply via email to