This is an automated email from the ASF dual-hosted git repository.

xintongsong pushed a commit to branch release-0.2
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/release-0.2 by this push:
     new e0a6d511 Fix prune state corner cases around checkpoints (#649)
e0a6d511 is described below

commit e0a6d511ed574f56cd36dd7be12d3cb4fcfefe2e
Author: Joey Tong <[email protected]>
AuthorDate: Tue May 12 14:06:18 2026 +0800

    Fix prune state corner cases around checkpoints (#649)
---
 .../runtime/operator/ActionExecutionOperator.java  |  26 ++--
 .../operator/ActionExecutionOperatorTest.java      | 169 ++++++++++++++++++++-
 2 files changed, 185 insertions(+), 10 deletions(-)

diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 89e35131..91e419bb 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -127,6 +127,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
 
     private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker";
     private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber";
+    private static final String LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME =
+            "lastCompletedSequenceNumber";
     private static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents";
 
     private final AgentPlan agentPlan;
@@ -191,6 +193,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
 
     private transient ActionStateStore actionStateStore;
     private transient ValueState<Long> sequenceNumberKState;
+    private transient ValueState<Long> lastCompletedSequenceNumberKState;
     private transient ListState<Object> recoveryMarkerOpState;
     private transient Map<Long, Map<Object, Long>> checkpointIdToSeqNums;
 
@@ -288,6 +291,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
                         .getState(
                                 new ValueStateDescriptor<>(
                                         MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class));
+        lastCompletedSequenceNumberKState =
+                getRuntimeContext()
+                        .getState(
+                                new ValueStateDescriptor<>(
+                                        LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class));
 
         // init agent processing related state
         actionTasksKState =
@@ -578,11 +586,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
         if (currentInputEventFinished) {
             // Clean up sensory memory when a single run finished.
             actionTask.getRunnerContext().clearSensoryMemory();
+            lastCompletedSequenceNumberKState.update(sequenceNumber);
 
             // Once all sub-events and actions related to the current InputEvent are completed,
             // we can proceed to process the next InputEvent.
             int removedCount = removeFromListState(currentProcessingKeysOpState, key);
-            maybePruneState(key, sequenceNumber);
             checkState(
                     removedCount == 1,
                     "Current processing key count for key "
@@ -789,8 +797,14 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
                 .applyToAllKeys(
                         VoidNamespace.INSTANCE,
                         VoidNamespaceSerializer.INSTANCE,
-                        new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class),
-                        (key, state) -> keyToSeqNum.put(key, state.value()));
+                        new ValueStateDescriptor<>(
+                                LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class),
+                        (key, state) -> {
+                            Long completedSequenceNumber = state.value();
+                            if (completedSequenceNumber != null) {
+                                keyToSeqNum.put(key, completedSequenceNumber);
+                            }
+                        });
         checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum);
 
         super.snapshotState(context);
@@ -1067,12 +1081,6 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT
         }
     }
 
-    private void maybePruneState(Object key, long sequenceNum) throws Exception {
-        if (actionStateStore != null) {
-            actionStateStore.pruneState(key, sequenceNum);
-        }
-    }
-
     private void processEligibleWatermarks() throws Exception {
         Watermark mark = keySegmentQueue.popOldestWatermark();
         while (mark != null) {
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index a45a52d4..729f106e 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -29,6 +29,7 @@ import org.apache.flink.agents.plan.AgentPlan;
 import org.apache.flink.agents.plan.JavaFunction;
 import org.apache.flink.agents.plan.actions.Action;
 import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallResult;
 import org.apache.flink.agents.runtime.actionstate.InMemoryActionStateStore;
 import org.apache.flink.agents.runtime.eventlog.FileEventLogger;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -297,6 +298,62 @@ public class ActionExecutionOperatorTest {
         }
     }
 
+    @Test
+    void testDoesNotPruneBeforeCheckpointComplete() throws Exception {
+        AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
+        RecordingActionStateStore actionStateStore = new RecordingActionStateStore();
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(
+                                agentPlanWithStateStore, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(5L));
+            operator.waitInFlightEventsFinished();
+            assertThat(actionStateStore.getPrunedSeqNums()).isEmpty();
+
+            testHarness.snapshot(1L, 1L);
+            assertThat(actionStateStore.getPrunedSeqNums()).isEmpty();
+            testHarness.notifyOfCompletedCheckpoint(1L);
+
+            assertThat(actionStateStore.getPrunedSeqNums()).containsExactly(0L);
+        }
+    }
+
+    @Test
+    void testDoesNotPruneSeqsInFlight() throws Exception {
+        AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
+        RecordingActionStateStore actionStateStore = new RecordingActionStateStore();
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(
+                                agentPlanWithStateStore, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(5L));
+            operator.waitInFlightEventsFinished();
+            actionStateStore.clearPruneCalls();
+
+            testHarness.processElement(new StreamRecord<>(5L));
+            assertThat(testHarness.getTaskMailbox().size()).isEqualTo(1);
+
+            testHarness.snapshot(1L, 1L);
+            testHarness.notifyOfCompletedCheckpoint(1L);
+
+            assertThat(actionStateStore.getPrunedSeqNums()).containsExactly(0L);
+        }
+    }
+
     @Test
     void testEventLogBaseDirFromAgentConfig() throws Exception {
         String baseLogDir = "/tmp/flink-agents-test";
@@ -461,7 +518,7 @@ public class ActionExecutionOperatorTest {
     }
 
     @Test
-    void testActionStateStoreCleanupAfterOutputEvent() throws Exception {
+    void testActionStateStoreCleanupAfterCheckpointComplete() throws Exception {
         AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
 
         try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
@@ -496,10 +553,66 @@ public class ActionExecutionOperatorTest {
             actionStateStoreField.setAccessible(true);
             InMemoryActionStateStore actionStateStore =
                     (InMemoryActionStateStore) actionStateStoreField.get(operator);
+            assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+
+            testHarness.snapshot(1L, 1L);
+            testHarness.notifyOfCompletedCheckpoint(1L);
+
             assertThat(actionStateStore.getKeyedActionStates()).isEmpty();
         }
     }
 
+    @Test
+    void testEarlierCheckpointReplayKeepsDurableState() throws Exception {
+        AgentPlan agentPlan = TestAgent.getDurableSyncAgentPlan();
+        InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(true);
+        OperatorSubtaskState snapshot;
+
+        TestAgent.DURABLE_CALL_COUNTER.set(0);
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            // Simulate failure recovery from a checkpoint taken before this input was processed.
+            snapshot = testHarness.snapshot(1L, 1L);
+
+            testHarness.processElement(new StreamRecord<>(7L));
+            operator.waitInFlightEventsFinished();
+
+            assertThat(TestAgent.DURABLE_CALL_COUNTER.get()).isEqualTo(1);
+            assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+        }
+
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.initializeState(snapshot);
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            // Replay the same input after restoring from the earlier checkpoint.
+            testHarness.processElement(new StreamRecord<>(7L));
+            operator.waitInFlightEventsFinished();
+
+            List<StreamRecord<Object>> recordOutput =
+                    (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+            assertThat(recordOutput).hasSize(1);
+            assertThat(recordOutput.get(0).getValue()).isEqualTo(21L);
+            assertThat(TestAgent.DURABLE_CALL_COUNTER.get())
+                    .as("Durable supplier should not be re-executed during replay")
+                    .isEqualTo(1);
+        }
+    }
+
     @Test
     void testActionStateStoreReplayIncurNoFunctionCall() throws Exception {
         AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
@@ -1524,6 +1637,60 @@ public class ActionExecutionOperatorTest {
         }
     }
 
+    private static ActionState actionStateWithCallResults(CallResult... callResults) {
+        ActionState actionState = new ActionState(null);
+        for (CallResult callResult : callResults) {
+            actionState.addCallResult(callResult);
+        }
+        return actionState;
+    }
+
+    private static void seedActionState(
+            InMemoryActionStateStore actionStateStore,
+            long key,
+            long input,
+            AgentPlan agentPlan,
+            String actionName,
+            ActionState actionState)
+            throws Exception {
+        InputEvent event = new InputEvent(input);
+        Action action = agentPlan.getActions().get(actionName);
+        actionStateStore.put(key, 0L, action, event, actionState);
+    }
+
+    private static ActionState getStoredActionState(
+            InMemoryActionStateStore actionStateStore,
+            long key,
+            long input,
+            AgentPlan agentPlan,
+            String actionName)
+            throws Exception {
+        InputEvent event = new InputEvent(input);
+        Action action = agentPlan.getActions().get(actionName);
+        return actionStateStore.get(key, 0L, action, event);
+    }
+
+    private static class RecordingActionStateStore extends InMemoryActionStateStore {
+        private final List<Long> prunedSeqNums = new java.util.ArrayList<>();
+
+        private RecordingActionStateStore() {
+            super(false);
+        }
+
+        @Override
+        public void pruneState(Object key, long seqNum) {
+            prunedSeqNums.add(seqNum);
+        }
+
+        private void clearPruneCalls() {
+            prunedSeqNums.clear();
+        }
+
+        private List<Long> getPrunedSeqNums() {
+            return prunedSeqNums;
+        }
+    }
+
     private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)
             throws Exception {
         assertThat(mailbox.size()).isEqualTo(expectedSize);

Reply via email to